iT邦幫忙

第 11 屆 iThome 鐵人賽

DAY 13
1

Estimator API

主旨:了解Estimator API並實際操作

在了解Estimator API的部分,我們將會學到:

  • 用簡單的方式建造可量產的ML模型
  • 在巨大到無法全部放進記憶體的資料上做訓練
  • 在TensorBoard上面監測訓練衡量值

Estimator API

Estimator API是TensorFlow API中最高的階級(參考Day11中TensorFlow API 的階層),使用Estimator API有下列優點:

  • 快速的建構模型
  • 產生checkpoint檔案(*.ckpt檔案),用於暫停或繼續訓練
  • 訓練在超出記憶體的資料量上
  • 訓練、評估、監測都可以輕鬆做到
  • 分散式訓練
  • 超參數的搜索、調整
  • 提供生產用模型

下面我們用例子來說明Estimator API的用法。

預製的estimators

如果今天要建構一個線性迴歸的模型,除了從頭開始撰寫之外,使用 tf.estimator.LinearRegressor() 可以更快速的建構出來,如下面程式碼:

宣告好之後就是訓練和預測:

以預測房地產價格為例,完整個程式碼可寫成:

這樣短短幾行,我們就完成了線性迴歸模型的建置,當然若要建構更複雜如DNN的模型也可以改成使用 tf.estimator.DNNRegressor() ,如下圖:

檢查點checkpoint的使用

使用checkpoint大致上可分為三種情況:

  1. 繼續之前的模型訓練
  2. 失敗後的恢復
  3. 使用訓練好的模型來做預測

checkpoint使用方式如下圖,在宣告estimator內加入要存放checkpoint的路徑便可。

透過這樣的方式,在訓練的時候就會把checkpoint存下來,若該路徑本來就有checkpoint的話,就會該checkpoint繼續開始訓練。

使用記憶體內的資料來訓練

這邊以一般常用的numpy和pandas資料類型舉例,只要使用 tf.estimator.inputs.numpy_input_fn()tf.estimator.inputs.pandas_input_fn() 將資料x,y定義好,訓練的一些參數如批次大小、epoch數、要不要隨機洗亂資料...等等,就可以供後面模型訓練的輸入做使用。

全部模型訓練的程式大概如下面所示, model.train(train_input_fn(XXX)) 這裡就是我們給入的記憶體內資料:

接著我們用實作來更詳細的看看上面提到的部分。

[GCP Lab實作-7]:在TensorFlow中使用Estimator API撰寫ML模型

這個實作中,我們將學會:

  • 使用tf.estimator建造ML模型和評估表現

[Part 1]:開始使用TensorFlow Estimator API

  1. 登入GCP,開啟Notebooks後,複製課程 Github repo (如Day9的Part 1 & 2步驟)。

  2. 在左邊的資料夾結構,點進 training-data-analyst > courses > machine_learning > deepdive > 03_tensorflow,然後打開檔案 b_estimator.ipynb

  3. 首先先將資料import進來,這邊的資料只有一部分之前用到的紐約計程車資料(7700筆)。

  1. 接著定義 train 和 eval 的 input function,在train的時候我們會隨機打亂資料,所以 shuffle = True ,但是 eval 就不需要所以是 shuffle = False

  1. prediction 的 input function 定義如下,因為是預測y,所以這邊y不需要再給值,所以 y = None

  1. 再來定義feature的行:

  1. 接著就可以把線性迴歸的模型寫出來如下,並開始訓練:

  1. 訓練好模型後,來看看 eval 的 RMSE 表現如何,大約是10.44:

  1. 再來看看 prediction 的 RMSE 表現,大概是11.79:

  1. 如果大家還有印象,在Day10的lab中,Valid RMSE約是9.35,Test RMSE約是5.44,都比目前我們的線性迴歸模型要好太多了,或許是因為我們沒有用全部的資料而只用了一部分的原因?那麼我們就使用全部的資料來看看Test RMSE是多少,發現有下降一些約是9.47:

在這個Lab我們雖然訓練了一個線性迴歸模型,但表現卻比不上用簡單的經驗和直覺計算出來的結果,別擔心,後面的章節和實作將會介紹到怎麼讓ML模型的表現更好!


今天介紹了Estimator API,明天我們將介紹到 “如何在巨大的資料集上做訓練”。


上一篇
鐵人賽Day12 - Intro to TensorFlow (2/6)
下一篇
鐵人賽Day14 - Intro to TensorFlow (4/6)
系列文
Machine Learning with TensorFlow on Google Cloud Platform30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言